import os
import mindspore as ms

# Prune a MindSpore checkpoint down to a subset of decoder layers and save the
# result to a new location. A parameter whose name carries a decoder-layer
# index ("...model.layers.<idx>.<rest>") is kept only when <idx> is listed in
# `keep_layer_nums`; every other parameter is copied over unchanged.
keep_layer_nums = list(range(2))
ckpt_path = "/home/zhangsenzhen/checkpoint_download/Ziya-LLaMA-13B-v1.1-hf/ziya/single/rank_0/transform.ckpt"
params = ms.load_checkpoint(ckpt_path)

# Decoder-layer parameter names contain this prefix followed by the layer index.
layer_prefix = "model.layers."
# Set for O(1) membership tests; keep_layer_nums itself stays a list.
keep_layers = set(keep_layer_nums)

new_params = []
for k, v in params.items():
    # Match on the full prefix rather than the bare substring "layers". The old
    # check crashed with IndexError on names that contain "layers" but not
    # "model.layers.", and with ValueError when the segment after the prefix
    # was not an integer.
    _, found, rest = k.partition(layer_prefix)
    layer_idx = rest.split(".", 1)[0]
    if found and layer_idx.isdecimal():
        if int(layer_idx) in keep_layers:
            new_params.append({'name': k, 'data': v})
    else:
        # No decoder-layer index in the name: always keep the parameter.
        new_params.append({'name': k, 'data': v})

save_path = "/home/zhangsenzhen/checkpoint_download/Ziya-LLaMA-13B-v1.1-hf/ziya/single_2layer/rank_0/transform.ckpt"
save_dir = os.path.dirname(save_path)
# Create the output directory tree if needed; exist_ok makes re-runs harmless.
os.makedirs(save_dir, exist_ok=True)
ms.save_checkpoint(new_params, save_path)